import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from itertools import cycle

def heatmap(x, y, Z, trajectories, args=None, *, window=9):
    """Draw ``Z`` as a heatmap, overlay trajectories, and save the figure as a PNG.

    The image is written to
    ``../image/{task_name}/{id}_heatmap_{xmin}_{xmax}_{xnum}_{ymin}_{ymax}_{ynum}_{vmin}_{vmax}_{vlevel}.png``
    (relative to the current working directory). The output directory is
    created if it is missing. The figure is closed afterwards, so repeated
    calls do not accumulate open pyplot figures.

    Args:
        x: Not used by this function; kept for call-site compatibility.
        y: Not used by this function; kept for call-site compatibility.
        Z: Array passed to ``plt.imshow``. It is drawn with ``origin='lower'``
            over the extent ``(args.xmin, args.xmax, args.ymin, args.ymax)``.
        trajectories: Iterable of ``(name, path)`` pairs, or a falsy value for
            none. ``path[0]`` and ``path[1]`` are sliced as the x and y
            coordinate sequences of each trajectory.
        args: Namespace-like object. It must provide ``xmin``, ``xmax``,
            ``ymin``, ``ymax``, ``xnum``, ``ynum``, ``vmin``, ``vmax``,
            ``vlevel``, ``start_epoch``, ``task_name`` and ``id``. It is
            required despite its ``None`` default.
        window: Number of consecutive points, starting at index
            ``args.start_epoch``, drawn from each trajectory. The default of 9
            matches the previously hard-coded value.

    Raises:
        ValueError: If ``args`` is None.
    """
    import os

    if args is None:
        # Fail with a clear message instead of an AttributeError on NoneType below.
        raise ValueError("heatmap() requires `args` (plot extent, epoch window and output naming)")

    fig = plt.figure(figsize=(8, 6))
    try:
        plt.imshow(Z, extent=(args.xmin, args.xmax, args.ymin, args.ymax),
                   origin='lower', cmap='viridis', interpolation='nearest')
        plt.colorbar(label="Value")

        plotted_any = False
        if trajectories:
            # Cycle styles and colors independently so that neighbouring
            # trajectories differ in both line style and color.
            style_cycle = cycle(['-', '--', '-.', ':'])
            color_cycle = cycle(['r', 'g', 'b', 'c', 'm', 'y', 'k'])
            start = args.start_epoch
            stop = start + window
            for name, path in trajectories:
                plt.plot(path[0][start:stop], path[1][start:stop],
                         linestyle=next(style_cycle), color=next(color_cycle),
                         label=f'{name}', alpha=0.9)
                plotted_any = True

        plt.title("Heatmap with Trajectories")
        plt.xlabel('X-axis')
        plt.ylabel('Y-axis')
        # Only add a legend when there are labelled lines. Otherwise matplotlib
        # warns "No artists with labels found to put in legend".
        if plotted_any:
            plt.legend()

        out_path = f'../image/{args.task_name}/{args.id}_heatmap_{args.xmin}_{args.xmax}_{args.xnum}_{args.ymin}_{args.ymax}_{args.ynum}_{args.vmin}_{args.vmax}_{args.vlevel}.png'
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)